import re

# Per-language extraction settings, keyed by the short language code that
# callers pass as `lang_code`.
#   full_name -- tag expected right after ``` in a fenced answer (matched
#                case-insensitively); also passed to get_function_name.
#   indent    -- fallback number of spaces before the closing "}" searched
#                for when the target function cannot be located.
#   main      -- optional entry-point signature; extraction discards the
#                fenced code from this point onward.
# NOTE(review): the name is misspelled ("languge"); renaming it requires
# updating every reference.
languge_settings = dict(
    python=dict(full_name="Python", indent=4),
    cpp=dict(full_name="cpp", indent=0, main="int main()"),
    java=dict(full_name="Java", indent=4, main="public static void main"),
    cs=dict(full_name="csharp", indent=0, main="public static void Main"),
    php=dict(full_name="PHP", indent=0),
    ts=dict(full_name="TypeScript", indent=0),
    js=dict(full_name="JavaScript", indent=0),
    sh=dict(full_name="Bash", indent=0),
)


def get_function_name(question: str, lang: str):
    """Locate the target function's signature line in a prompt.

    Blank lines are ignored. For Python (``lang`` compared case-insensitively)
    the last line that begins with ``def `` is the signature; for every other
    language the last non-blank line is.

    Returns:
        A ``(func_name, func_prefix)`` tuple. ``func_name`` is the signature
        text before the first ``(`` (Python) or ``{`` (other languages),
        stripped; for Python it still carries the leading ``def`` keyword.
        ``func_prefix`` is every non-blank prompt line before the signature,
        joined with newlines.

    Raises:
        IndexError: if a Python prompt has no line starting with ``def ``, or
            a non-Python prompt is empty/blank.
    """
    lines = [line for line in question.strip().split("\n") if line.strip()]

    if lang.lower() != "python":
        # Brace languages: the prompt ends with the signature line.
        signature = lines[-1]
        return signature.split("{")[0].strip(), "\n".join(lines[:-1])

    def_positions = [pos for pos, line in enumerate(lines) if line.startswith("def ")]
    sig_pos = def_positions[-1]
    return lines[sig_pos].split("(")[0].strip(), "\n".join(lines[:sig_pos])


def extract_generation_code(example: dict, lang_code: str, verbose: bool = False):
    """Extract the target function from a model's fenced-code answer.

    The first ```<full_name> ... ``` block in the model output is taken
    (case-insensitive). If the language has a ``main`` signature configured,
    that signature and everything after it are dropped. The function named in
    the prompt is then located, and ``prompt prefix + function body`` is stored
    in ``example["generation"]``.

    Args:
        example: Record with ``task_id`` and ``prompt`` keys and the model
            answer under ``output`` (falling back to ``gpt_completion``).
        lang_code: Key into ``languge_settings`` (e.g. ``"cpp"``, ``"java"``).
        verbose: If true, print the extracted code block.

    Returns:
        The same ``example`` dict, mutated in place with a ``generation`` key.
        On any extraction failure a message is printed, and ``generation``
        falls back to the prompt followed by the raw output (empty if there
        was none).

    Raises:
        KeyError: if ``lang_code`` is unknown or ``task_id``/``prompt`` is
            missing.
    """
    task_id = example["task_id"]
    output = example.get("output", example.get("gpt_completion"))
    question = example["prompt"].strip()
    setting = languge_settings[lang_code]
    lang = setting["full_name"]
    indent = setting["indent"]

    try:
        # The fence tag comes from config; escape it so it is matched literally.
        pattern = f"```{re.escape(lang.lower())}\n(.*?)```"
        code_block: str = re.findall(pattern, output, re.DOTALL | re.IGNORECASE)[0]
        if verbose:
            print(">>> Task: {}\n{}".format(task_id, code_block))

        # Remove main: keep only the code before the entry-point signature.
        main_sig = setting.get("main")
        if main_sig and main_sig in code_block:
            code_block = code_block[: code_block.index(main_sig)]

        func_name, func_prefix = get_function_name(question, lang)

        try:
            start = code_block.lower().index(func_name.lower())
            # Count the spaces directly before the function name. Stop at the
            # start of the string so the lookup never wraps to index -1.
            indent = 0
            while start - indent > 0 and code_block[start - indent - 1] == " ":
                indent += 1
        except ValueError:
            # Function name not found: keep the whole block and the configured indent.
            start = 0

        # A closing brace at the function's indentation marks its end; fall
        # back to the end of the block when no such brace exists.
        try:
            end = code_block.rindex("\n" + " " * indent + "}")
        except ValueError:
            end = len(code_block)

        body = code_block[start:end]

        # `end` excludes the closing brace; re-append it for these languages.
        if lang_code.lower() in ["php", "ts", "js"]:
            body += "\n" + " " * indent + "}"

        example["generation"] = func_prefix + "\n" + body + "\n"

    except Exception as ex:
        print("Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format(ex, task_id, output))
        # `output` is None when neither key is present; don't crash in the fallback.
        example["generation"] = example["prompt"] + "\n" + (output or "")

    return example


def cleanup_code(code: str, language_type: str = None, dataset: str = None, issft: bool = False, stop_words=None):
    """Cleans up the generated code by truncating it at the first stop word.

    Args:
        code: Raw model completion.
        language_type: Language key, compared case-insensitively. ``"python"``
            and ``"ts"`` get special handling; ``None`` or any other value uses
            only ``stop_words``.
        dataset: Unused; kept for interface compatibility.
        issft: Python only. When true, first extract the answer from a
            ```python fence.
        stop_words: Extra stop strings (any iterable of str; defaults to none).
            Ignored for Python, which always uses its own fixed list. For
            TypeScript, import/export/console.log markers are appended.

    Returns:
        ``code`` cut just before the earliest stop-word occurrence (unchanged
        if no stop word occurs).
    """
    # None sentinel instead of a shared mutable default; list() also lets
    # callers pass tuples without breaking the concatenation below.
    stop_words = list(stop_words) if stop_words is not None else []
    language = (language_type or "").lower()

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # NOTE: replaces (rather than extends) any caller-supplied stop words.
        stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
    elif language == "ts":
        stop_words = stop_words + ["\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log"]
    # Java: the dedicated `_truncate_code_for_java` path is currently disabled,
    # so Java uses the generic truncation below.

    return _truncate_code_at_stopwords(code, stop_words)


def _clean_python_code_for_sft(code):
    code = code.replace("\r", "")
    if "```python" in code:
        code_start_idx = code.index("```python")
        code = code[code_start_idx:].replace("```python", "").strip()
        end_idx = code.find("```") if "```" in code else len(code)
        code = code[:end_idx].strip()

    return code


def _truncate_code_at_stopwords(code, stop_words):
    min_stop_idx = len(code)
    for stop_word in stop_words:
        stop_index = code.find(stop_word)
        if 0 <= stop_index < min_stop_idx:
            min_stop_idx = stop_index
    return code[:min_stop_idx]


def _truncate_code_for_java(code, stop_words):
    main_pattern = "public static void main"
    main_index = code.find(main_pattern)
    if main_index != -1:
        main_start_idx = code.rfind("\n", None, main_index)
        code = code[:main_start_idx] + "\n\n"

    min_stop_idx = len(code)
    for stop_word in stop_words:
        stop_index = code.rfind(stop_word)
        if 0 <= stop_index < min_stop_idx:
            min_stop_idx = stop_index
    return code[:min_stop_idx]
